{
  "cells": [
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# code by Tae Hwan Jung(Jeff Jung) @graykode\n",
        "# Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch\n",
        "#           https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert\n",
        "import math\n",
        "import re\n",
        "from random import *\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "\n",
        "# sample IsNext and NotNext to be same in small batch size\n",
        "def make_batch():\n",
        "    \"\"\"Build one mini-batch of BERT pre-training examples, balanced between IsNext and NotNext.\n",
        "\n",
        "    Sentence pairs are sampled at random until the batch contains\n",
        "    batch_size // 2 IsNext pairs and batch_size - batch_size // 2 NotNext\n",
        "    pairs. Reads the module-level names batch_size, sentences, token_list,\n",
        "    word_dict, number_dict, vocab_size, maxlen and max_pred.\n",
        "\n",
        "    Returns:\n",
        "        A list of [input_ids, segment_ids, masked_tokens, masked_pos, is_next]\n",
        "        entries. input_ids and segment_ids are zero-padded up to maxlen;\n",
        "        masked_tokens and masked_pos are zero-padded up to max_pred.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if IsNext pairs are required but fewer than two sentences\n",
        "            exist, in which case no consecutive pair could ever be drawn.\n",
        "    \"\"\"\n",
        "    # Integer quotas. Comparing the counters with batch_size / 2 never\n",
        "    # terminates for an odd batch_size (a counter cannot equal x.5).\n",
        "    n_positive = batch_size // 2\n",
        "    n_negative = batch_size - n_positive\n",
        "    if n_positive > 0 and len(sentences) < 2:\n",
        "        raise ValueError('make_batch needs at least two sentences to sample IsNext pairs')\n",
        "\n",
        "    batch = []\n",
        "    positive = negative = 0\n",
        "    while positive < n_positive or negative < n_negative:\n",
        "        # sample two random sentence indices\n",
        "        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))\n",
        "        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]\n",
        "        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]\n",
        "        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)\n",
        "\n",
        "        # MASK LM: about 15 % of the tokens, at least 1 and at most max_pred\n",
        "        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))\n",
        "        # [CLS] and [SEP] are never candidates for masking\n",
        "        cand_maked_pos = [i for i, token in enumerate(input_ids)\n",
        "                          if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]\n",
        "        shuffle(cand_maked_pos)\n",
        "        masked_tokens, masked_pos = [], []\n",
        "        for pos in cand_maked_pos[:n_pred]:\n",
        "            masked_pos.append(pos)\n",
        "            masked_tokens.append(input_ids[pos])\n",
        "            if random() < 0.8:  # 80 %: replace with [MASK]\n",
        "                input_ids[pos] = word_dict['[MASK]']\n",
        "            elif random() < 0.5:  # 10 %: replace with a random vocabulary entry\n",
        "                index = randint(0, vocab_size - 1)\n",
        "                input_ids[pos] = word_dict[number_dict[index]]\n",
        "            # remaining 10 %: the original token is kept\n",
        "\n",
        "        # Zero-pad the sequence up to maxlen\n",
        "        n_pad = maxlen - len(input_ids)\n",
        "        input_ids.extend([0] * n_pad)\n",
        "        segment_ids.extend([0] * n_pad)\n",
        "\n",
        "        # Zero-pad the prediction slots up to max_pred. The count of positions\n",
        "        # actually selected is used because it can be smaller than n_pred when\n",
        "        # there are fewer maskable tokens than n_pred.\n",
        "        n_pad = max_pred - len(masked_pos)\n",
        "        masked_tokens.extend([0] * n_pad)\n",
        "        masked_pos.extend([0] * n_pad)\n",
        "\n",
        "        # A pair is IsNext only when sentence b directly follows sentence a\n",
        "        is_next = tokens_a_index + 1 == tokens_b_index\n",
        "        if is_next and positive < n_positive:\n",
        "            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext\n",
        "            positive += 1\n",
        "        elif not is_next and negative < n_negative:\n",
        "            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext\n",
        "            negative += 1\n",
        "    return batch\n",
        "# Preprocessing finished\n",
        "\n",
        "def get_attn_pad_mask(seq_q, seq_k):\n",
        "    batch_size, len_q = seq_q.size()\n",
        "    batch_size, len_k = seq_k.size()\n",
        "    # eq(zero) is PAD token\n",
        "    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking\n",
        "    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k\n",
        "\n",
        "def gelu(x):\n",
        "    \"Implementation of the gelu activation function by Hugging Face\"\n",
        "    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))\n",
        "\n",
        "class Embedding(nn.Module):\n",
        "    \"\"\"Sum of token, learned position and segment embeddings, followed by LayerNorm.\n",
        "\n",
        "    Table sizes come from the module-level names vocab_size, maxlen,\n",
        "    n_segments and d_model.\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        super(Embedding, self).__init__()\n",
        "        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding\n",
        "        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding\n",
        "        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding\n",
        "        self.norm = nn.LayerNorm(d_model)\n",
        "\n",
        "    def forward(self, x, seg):\n",
        "        \"\"\"Embed token ids x and segment ids seg (both batch_size x seq_len).\n",
        "\n",
        "        Returns the normalised sum of the three embeddings, one d_model\n",
        "        vector per input position.\n",
        "        \"\"\"\n",
        "        seq_len = x.size(1)\n",
        "        # Create the position ids on the same device as x; the default device\n",
        "        # would make the embedding lookup fail when the inputs live on a GPU.\n",
        "        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)\n",
        "        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)\n",
        "        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n",
        "        return self.norm(embedding)\n",
        "\n",
        "class ScaledDotProductAttention(nn.Module):\n",
        "    \"\"\"Parameter-free attention: softmax(Q K^T / sqrt(d_k)) applied to V.\"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "    def forward(self, Q, K, V, attn_mask):\n",
        "        \"\"\"Return (context, attn); positions where attn_mask is set are pushed to -1e9 before the softmax.\"\"\"\n",
        "        # similarity of every query with every key, scaled by sqrt(d_k)\n",
        "        similarity = torch.matmul(Q, K.transpose(-1, -2))\n",
        "        scaled = similarity / np.sqrt(d_k)\n",
        "        # in-place fill so that masked keys get (almost) zero weight\n",
        "        scaled.masked_fill_(attn_mask, -1e9)\n",
        "        softmax = nn.Softmax(dim=-1)\n",
        "        weights = softmax(scaled)\n",
        "        return torch.matmul(weights, V), weights\n",
        "\n",
        "class MultiHeadAttention(nn.Module):\n",
        "    \"\"\"Multi-head attention with output projection, residual connection and LayerNorm.\n",
        "\n",
        "    Sizes come from the module-level names d_model, d_k, d_v and n_heads.\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        super(MultiHeadAttention, self).__init__()\n",
        "        self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
        "        self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
        "        self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
        "        # The output projection and the layer norm are created once here so\n",
        "        # that they are registered as sub-modules and updated by the optimizer.\n",
        "        # Building them inside forward() would re-initialise them with fresh\n",
        "        # random weights on every call and keep them out of model.parameters().\n",
        "        self.linear = nn.Linear(n_heads * d_v, d_model)\n",
        "        self.norm = nn.LayerNorm(d_model)\n",
        "\n",
        "    def forward(self, Q, K, V, attn_mask):\n",
        "        \"\"\"Attend from Q over K/V.\n",
        "\n",
        "        Q is [batch_size x len_q x d_model]; K and V are\n",
        "        [batch_size x len_k x d_model]; attn_mask is\n",
        "        [batch_size x len_q x len_k]. Returns (output, attn) where output is\n",
        "        [batch_size x len_q x d_model] and attn holds the per-head weights.\n",
        "        \"\"\"\n",
        "        residual, batch_size = Q, Q.size(0)\n",
        "        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)\n",
        "        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]\n",
        "        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]\n",
        "        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]\n",
        "\n",
        "        # every head uses the same padding mask\n",
        "        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]\n",
        "\n",
        "        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q x len_k]\n",
        "        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)\n",
        "        # merge the heads back: [batch_size x len_q x n_heads * d_v]\n",
        "        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
        "        output = self.linear(context)\n",
        "        return self.norm(output + residual), attn # output: [batch_size x len_q x d_model]\n",
        "\n",
        "class PoswiseFeedForwardNet(nn.Module):\n",
        "    \"\"\"Two-layer feed-forward net (d_model -> d_ff -> d_model) with a GELU in between.\"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(d_model, d_ff)\n",
        "        self.fc2 = nn.Linear(d_ff, d_model)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Expand to d_ff, apply gelu, project back to d_model.\"\"\"\n",
        "        expanded = self.fc1(x)\n",
        "        activated = gelu(expanded)\n",
        "        return self.fc2(activated)\n",
        "\n",
        "class EncoderLayer(nn.Module):\n",
        "    \"\"\"One encoder block: multi-head self-attention followed by the position-wise feed-forward net.\"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.enc_self_attn = MultiHeadAttention()\n",
        "        self.pos_ffn = PoswiseFeedForwardNet()\n",
        "\n",
        "    def forward(self, enc_inputs, enc_self_attn_mask):\n",
        "        \"\"\"Return (enc_outputs, attn) for one pass through the block.\"\"\"\n",
        "        # self-attention: the same tensor serves as query, key and value\n",
        "        attended, attn = self.enc_self_attn(\n",
        "            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
        "        # feed-forward output: [batch_size x len_q x d_model]\n",
        "        return self.pos_ffn(attended), attn\n",
        "\n",
        "class BERT(nn.Module):\n",
        "    \"\"\"Encoder stack with a next-sentence classification head and a masked-LM head.\n",
        "\n",
        "    The masked-LM decoder shares its weight matrix with the token embedding.\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.embedding = Embedding()\n",
        "        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
        "        # pooler for the classification head\n",
        "        self.fc = nn.Linear(d_model, d_model)\n",
        "        self.activ1 = nn.Tanh()\n",
        "        # transform applied before the masked-LM decoder\n",
        "        self.linear = nn.Linear(d_model, d_model)\n",
        "        self.activ2 = gelu\n",
        "        self.norm = nn.LayerNorm(d_model)\n",
        "        self.classifier = nn.Linear(d_model, 2)\n",
        "        # weight tying: the decoder reuses the token embedding matrix\n",
        "        tied_weight = self.embedding.tok_embed.weight\n",
        "        vocab, dim = tied_weight.size()\n",
        "        self.decoder = nn.Linear(dim, vocab, bias=False)\n",
        "        self.decoder.weight = tied_weight\n",
        "        self.decoder_bias = nn.Parameter(torch.zeros(vocab))\n",
        "\n",
        "    def forward(self, input_ids, segment_ids, masked_pos):\n",
        "        \"\"\"Return (logits_lm, logits_clsf).\n",
        "\n",
        "        logits_lm is [batch_size, max_pred, n_vocab] and logits_clsf is\n",
        "        [batch_size, 2].\n",
        "        \"\"\"\n",
        "        pad_mask = get_attn_pad_mask(input_ids, input_ids)\n",
        "        hidden = self.embedding(input_ids, segment_ids)\n",
        "        for encoder in self.layers:\n",
        "            # the attention weights [batch_size, n_heads, len, len] are not used here\n",
        "            hidden, _ = encoder(hidden, pad_mask)\n",
        "        # hidden: [batch_size, len, d_model]\n",
        "\n",
        "        # sentence-level decision is read from the first token ([CLS])\n",
        "        cls_hidden = hidden[:, 0]\n",
        "        pooled = self.activ1(self.fc(cls_hidden))  # [batch_size, d_model]\n",
        "        logits_clsf = self.classifier(pooled)  # [batch_size, 2]\n",
        "\n",
        "        # pick the hidden states at the masked positions\n",
        "        gather_index = masked_pos[:, :, None].expand(-1, -1, hidden.size(-1))  # [batch_size, max_pred, d_model]\n",
        "        picked = torch.gather(hidden, 1, gather_index)  # [batch_size, max_pred, d_model]\n",
        "        picked = self.linear(picked)\n",
        "        picked = self.norm(self.activ2(picked))\n",
        "        logits_lm = self.decoder(picked) + self.decoder_bias  # [batch_size, max_pred, n_vocab]\n",
        "\n",
        "        return logits_lm, logits_clsf\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    # Toy BERT pre-training demo: build a vocabulary from a short dialogue,\n",
        "    # train on a single fixed batch, then show the predictions for its first example.\n",
        "\n",
        "    # BERT Parameters\n",
        "    maxlen = 30 # maximum of length\n",
        "    batch_size = 6\n",
        "    max_pred = 5  # max tokens of prediction\n",
        "    n_layers = 6 # number of Encoder of Encoder Layer\n",
        "    n_heads = 12 # number of heads in Multi-Head Attention\n",
        "    d_model = 768 # Embedding Size\n",
        "    d_ff = 768 * 4  # 4*d_model, FeedForward dimension\n",
        "    d_k = d_v = 64  # dimension of K(=Q), V\n",
        "    n_segments = 2\n",
        "\n",
        "    text = (\n",
        "        'Hello, how are you? I am Romeo.\\n'\n",
        "        'Hello, Romeo My name is Juliet. Nice to meet you.\\n'\n",
        "        'Nice meet you too. How are you today?\\n'\n",
        "        'Great. My baseball team won the competition.\\n'\n",
        "        'Oh Congratulations, Juliet\\n'\n",
        "        'Thanks you Romeo'\n",
        "    )\n",
        "    sentences = re.sub(\"[.,!?\\\\-]\", '', text.lower()).split('\\n')  # filter '.', ',', '?', '!'\n",
        "    word_list = list(set(\" \".join(sentences).split()))\n",
        "    # ids 0-3 are reserved for the special tokens; words start at 4\n",
        "    word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}\n",
        "    for i, w in enumerate(word_list):\n",
        "        word_dict[w] = i + 4\n",
        "    number_dict = {i: w for i, w in enumerate(word_dict)}\n",
        "    vocab_size = len(word_dict)\n",
        "\n",
        "    token_list = list()\n",
        "    for sentence in sentences:\n",
        "        arr = [word_dict[s] for s in sentence.split()]\n",
        "        token_list.append(arr)\n",
        "\n",
        "    model = BERT()\n",
        "    # Sentence classification uses the plain loss. The masked-LM loss skips\n",
        "    # [PAD] targets, i.e. the filler slots make_batch appends up to max_pred,\n",
        "    # so the model is not trained to emit [PAD] for them.\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    criterion_lm = nn.CrossEntropyLoss(ignore_index=word_dict['[PAD]'])\n",
        "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
        "\n",
        "    batch = make_batch()\n",
        "    input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))\n",
        "\n",
        "    for epoch in range(100):\n",
        "        optimizer.zero_grad()\n",
        "        logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)\n",
        "        # CrossEntropyLoss expects the class axis second: [batch_size, n_vocab, max_pred]\n",
        "        loss_lm = criterion_lm(logits_lm.transpose(1, 2), masked_tokens) # for masked LM\n",
        "        loss_clsf = criterion(logits_clsf, isNext) # for sentence classification\n",
        "        loss = loss_lm + loss_clsf\n",
        "        if (epoch + 1) % 10 == 0:\n",
        "            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "    # Predict mask tokens and isNext for the first example of the batch\n",
        "    input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0]))\n",
        "    print(text)\n",
        "    print([number_dict[w.item()] for w in input_ids[0] if number_dict[w.item()] != '[PAD]'])\n",
        "\n",
        "    # inference only: no gradient bookkeeping needed\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)\n",
        "\n",
        "    # make_batch appends its [PAD] fillers at the end, so the real masked\n",
        "    # slots are the leading ones; report predictions for exactly those.\n",
        "    true_tokens = [tok.item() for tok in masked_tokens[0] if tok.item() != 0]\n",
        "    predicted_tokens = logits_lm.argmax(dim=2)[0][:len(true_tokens)].tolist()\n",
        "    print('masked tokens list : ', true_tokens)\n",
        "    print('predict masked tokens list : ', predicted_tokens)\n",
        "\n",
        "    predicted_is_next = logits_clsf.argmax(dim=1)[0].item()\n",
        "    print('isNext : ', True if isNext else False)\n",
        "    print('predict isNext : ', True if predicted_is_next else False)\n"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "anaconda-cloud": {},
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.1"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}